#!/usr/bin/env python

import roslib; roslib.load_manifest('bwi_apps')
import rospy

import yaml, os
from PIL import Image

from tf.transformations import quaternion_from_euler
from tf import TransformBroadcaster
from nav_msgs.srv import GetMap
from nav_msgs.srv import GetMapResponse
from nav_msgs.msg import OccupancyGrid
from nav_msgs.msg import MapMetaData
from bwi_msgs.msg import MultiLevelMapLink
from bwi_msgs.msg import MultiLevelMapData
from bwi_msgs.msg import LevelMetaData

import bwi_utils.utils
utils = bwi_utils

class MapServer:
  """Serve several occupancy-grid maps (one per level) from a single node.

  For every entry of the ``~maps`` parameter the server loads the map image
  described by the entry's yaml file, offers it through a
  ``<level>/static_map`` service and latches it on ``<level>/map`` and
  ``<level>/map_metadata``. The level meta data and the ``~links`` between
  levels are latched on ``map_metadata`` as a MultiLevelMapData message, and
  the transform from ``~global_frame_id`` to every level frame is broadcast
  at ``~publish_rate`` Hz.
  """

  # Fields a map yaml file has to provide before the level can be loaded.
  _REQUIRED_YAML_FIELDS = ('image', 'resolution', 'origin', 'occupied_thresh',
      'free_thresh')

  def __init__(self):
    """Initialise the node, load every level and broadcast tf until shutdown.

    The constructor blocks in the tf broadcast loop. If the map list is
    missing or a map yaml file cannot be used, a fatal message is logged and
    the constructor returns early without publishing anything further.

    Raises:
      KeyError: a ``~maps`` entry lacks ``yaml``/``origin``, or a link refers
        to a level that was not loaded.
      TypeError: a ``~maps`` origin does not have 3, 6 or 7 elements.
    """
    rospy.init_node('bwi_map_server')

    #get parameters
    self.global_frame_id = rospy.get_param('~global_frame_id', '/map')
    self.publish_rate = rospy.get_param('~publish_rate', 5)
    try:
      self.map_list = rospy.get_param('~maps')
    except KeyError:
      # get_param without a default raises KeyError for an unset parameter.
      rospy.logfatal("Please provide map list (~maps)")
      return

    try:
      self.map_links = rospy.get_param('~links')
    except KeyError:
      rospy.logwarn("No links provided between different map levels (~links). Planning from one level to another will fail")
      self.map_links = list()

    # process maps
    self.map_response = dict()
    self.service_server = dict()
    self.metadata_publisher = dict()
    self.map_publisher = dict()
    self.map_origin_position = dict()
    self.map_origin_orientation = dict()

    out_meta_data_msg = MultiLevelMapData()

    for map_name, level_params in self.map_list.items():
      level_data = self._setupLevel(map_name, level_params)
      if level_data is None:
        # _setupLevel already logged the fatal error.
        return
      # Previously the level meta data was built and then thrown away, so
      # subscribers never learned about any level.
      # NOTE(review): assumes MultiLevelMapData stores its LevelMetaData
      # entries in a `levels` list -- confirm against the bwi_msgs definition.
      out_meta_data_msg.levels.append(level_data)

    #process links
    for map_link in self.map_links:
      link = MultiLevelMapLink()
      self._fillLinkPoint(link.from_point, map_link['from_point'])
      self._fillLinkPoint(link.to_point, map_link['to_point'])
      out_meta_data_msg.links.append(link)

    # NOTE: this method used to fill and publish an `out_msg` of type
    # MultiLevelMap on "map" as well, but neither the variable nor the message
    # class was ever defined/imported here, so start-up always died with a
    # NameError. Only the MultiLevelMapData message is published now.

    # publish MultiLevelMapData
    out_meta_data_msg.header.frame_id = self.global_frame_id
    out_meta_data_msg.header.stamp = rospy.Time.now()
    self.metadata_pub = rospy.Publisher("map_metadata", MultiLevelMapData, latch=True)
    self.metadata_pub.publish(out_meta_data_msg)

    # publish tf tree from global frame of reference
    # One broadcaster is enough; building a new one per level and per cycle
    # only repeated the set-up work.
    broadcaster = TransformBroadcaster()
    rate = rospy.Rate(self.publish_rate)
    try:
      while not rospy.is_shutdown():
        for map_name, map_response in self.map_response.items():
          broadcaster.sendTransform(self.map_origin_position[map_name],
                                    self.map_origin_orientation[map_name],
                                    rospy.Time.now(),
                                    map_response.map.header.frame_id,
                                    self.global_frame_id)

        rate.sleep()
    except rospy.ROSInterruptException:
      # Raised by rate.sleep() when the node is shut down; nothing to clean up.
      pass

  @staticmethod
  def _parseOrigin(origin):
    """Split a level origin into a (position, orientation quaternion) pair.

    Accepted layouts are ``[x, y, yaw]``, ``[x, y, z, roll, pitch, yaw]`` and
    ``[x, y, z, qx, qy, qz, qw]``.

    Raises:
      TypeError: ``origin`` has any other number of elements.
    """
    count = len(origin)
    if count == 3:
      return (origin[0], origin[1], 0), quaternion_from_euler(0, 0, origin[2])
    if count == 6:
      return ((origin[0], origin[1], origin[2]),
              quaternion_from_euler(origin[3], origin[4], origin[5]))
    if count == 7:
      return ((origin[0], origin[1], origin[2]),
              (origin[3], origin[4], origin[5], origin[6]))
    raise TypeError("Unable to parse origin tag")

  def _setupLevel(self, map_name, level_params):
    """Load one level, advertise its service/topics and describe it.

    Args:
      map_name: key of the level in ``~maps``; used as topic/frame prefix.
      level_params: the ``~maps`` entry, providing ``yaml`` (path of the map
        yaml file) and ``origin`` (see _parseOrigin).

    Returns:
      The LevelMetaData for the level, or None after logging a fatal error
      when the yaml file cannot be read or lacks a required field.
    """
    yaml_file = level_params['yaml']
    position, orientation = self._parseOrigin(level_params['origin'])
    self.map_origin_position[map_name] = position
    self.map_origin_orientation[map_name] = orientation

    try:
      # safe_load is sufficient for the plain scalars/lists of a map yaml file
      # and, unlike a bare yaml.load(stream), works on current PyYAML releases.
      with open(yaml_file, 'r') as yaml_stream:
        map_info = yaml.safe_load(yaml_stream)
    except (IOError, OSError, yaml.YAMLError):
      map_info = None
    if not isinstance(map_info, dict):
      rospy.logfatal("Unable to load yaml file for map: %s" %yaml_file)
      return None

    missing = [field for field in self._REQUIRED_YAML_FIELDS
               if map_info.get(field) is None]
    if missing:
      rospy.logfatal("Map yaml file %s is missing required field(s): %s"
          % (yaml_file, ', '.join(missing)))
      return None

    # A relative image path is resolved against the yaml file's directory.
    image_file = map_info['image']
    if not os.path.isabs(image_file):
      yaml_file_dir = os.path.dirname(os.path.realpath(yaml_file))
      image_file = os.path.join(yaml_file_dir, image_file)

    response = self.loadMapFromFile(image_file, map_info['resolution'],
        map_info.get('negate'), map_info['occupied_thresh'],
        map_info['free_thresh'], map_info['origin'])
    load_time = rospy.Time.now()
    response.map.info.map_load_time = load_time
    response.map.header.frame_id = '/' + map_name + '/map'
    response.map.header.stamp = load_time
    # Stored before the service is advertised: levelCallback looks it up here.
    self.map_response[map_name] = response

    # Publish map service
    self.service_server[map_name] = rospy.Service(map_name + '/static_map',
        GetMap, self.levelCallback)

    # Publish latched map messages
    self.metadata_publisher[map_name] = rospy.Publisher(map_name + '/map_metadata',
        MapMetaData, latch=True)
    self.metadata_publisher[map_name].publish(response.map.info)
    self.map_publisher[map_name] = rospy.Publisher(map_name + '/map',
        OccupancyGrid, latch=True)
    self.map_publisher[map_name].publish(response.map)

    # Construct the level meta data message
    level_data = LevelMetaData()
    level_data.level_id = map_name
    level_data.info = response.map.info
    return level_data

  def _fillLinkPoint(self, link_point, point_params):
    """Copy one end of a ``~links`` entry into a link message end point.

    Args:
      link_point: the ``from_point``/``to_point`` field of a MultiLevelMapLink.
      point_params: dict with ``level_name`` and a 3-element ``point``.

    Raises:
      KeyError: ``level_name`` does not name a loaded level.
    """
    link_point.level_name = point_params['level_name']
    link_point.frame_id = self.map_response[link_point.level_name].map.header.frame_id
    link_point.point.x = point_params['point'][0]
    link_point.point.y = point_params['point'][1]
    link_point.point.z = point_params['point'][2]

  def levelCallback(self, req):
    """Handle a ``<level>/static_map`` request.

    All level services share this handler; the level is recovered by matching
    the service name from the request's connection header against the
    resolved service name of every loaded level.

    Returns:
      The GetMapResponse of the matching level, or None if no level matches.
    """
    service = req._connection_header['service']
    for map_name, map_response in self.map_response.items():
      if service == rospy.resolve_name(map_name + '/static_map'):
        return map_response
    rospy.logerr("No map available for service: %s" % service)
    return None

  @staticmethod
  def _occupancyValue(color_average, negate, occupied_thresh, free_thresh):
    """Turn an averaged pixel colour (0-255) into an occupancy grid value.

    Dark pixels count as occupied unless ``negate`` is truthy, in which case
    bright pixels do.

    Returns:
      100 if the occupancy ratio is above ``occupied_thresh``, 0 if it is
      below ``free_thresh`` and -1 (unknown) otherwise.
    """
    if negate:
      occ = color_average / 255.0
    else:
      occ = (255 - color_average) / 255.0

    if occ > occupied_thresh:
      return 100
    if occ < free_thresh:
      return 0
    return -1

  def loadMapFromFile(self, image_file, resolution, negate, occupied_thresh,
      free_thresh, origin):
    """Build a GetMapResponse from a map image.

    Args:
      image_file: path of the image, opened with PIL.
      resolution: stored unchanged in the map info.
      negate: truthy if bright pixels denote obstacles instead of dark ones.
      occupied_thresh: occupancy ratio above which a cell becomes 100.
      free_thresh: occupancy ratio below which a cell becomes 0.
      origin: ``[x, y, yaw]`` of the map origin.

    Returns:
      A GetMapResponse whose map info and data are filled in; header and
      load time are left for the caller to set.
    """
    resp = GetMapResponse()

    image = Image.open(image_file)
    pix = image.load()

    width, height = image.size
    info = resp.map.info
    info.width = width
    info.height = height
    info.resolution = resolution

    info.origin.position.x = origin[0]
    info.origin.position.y = origin[1]
    info.origin.position.z = 0
    q = quaternion_from_euler(0,0,origin[2])
    info.origin.orientation.x = q[0]
    info.origin.orientation.y = q[1]
    info.origin.orientation.z = q[2]
    info.origin.orientation.w = q[3]

    # Pixels are either a single value or a sequence with one value per layer;
    # the first pixel tells which form this image uses.
    test_pxl = pix[0,0]
    is_multi_layer = isinstance(test_pxl, (list, tuple))
    num_layers = len(test_pxl) if is_multi_layer else 1

    classify = self._occupancyValue
    # Filled in a local list and assigned once, instead of going through
    # resp.map.data for every pixel.
    data = [-1] * (width * height)
    for j in range(height):
      # Image row j is written to grid row (height - j - 1), i.e. the rows are
      # stored in reverse order.
      row_start = width * (height - j - 1)
      for i in range(width):
        pxl = pix[i, j]
        if is_multi_layer:
          color_average = sum(pxl) / num_layers
        else:
          color_average = pxl
        data[row_start + i] = classify(color_average, negate,
            occupied_thresh, free_thresh)

    resp.map.data = data
    return resp

if __name__ == '__main__':
  # Constructing the server initialises the ROS node and blocks inside the
  # constructor's tf broadcast loop until the node is shut down.
  MapServer()
